#include <stdlib.h>
#include <math.h>
#include "../../Include/manager/log_manager.h"

/**
 * ReLU激励函数。当x大于0是，返回x；否则返回0。
 * @param x
 * @return
 */
double ReLU(double x) {
    if (x > 0) {
        return x;
    } else {
        return 0;
    }
}
/**************************************************************************************************/

/**
 * ReLU函数的数组参数的行数和列数
 */
int RELU_ARRAY_ROW;
int RELU_ARRAY_COLUMN;

/**
 * ReLU激励函数。当x大于0是，返回x；否则返回0。
 * @param x 数组
 * @return
 */
double **ReLUArray(double x[RELU_ARRAY_ROW][RELU_ARRAY_COLUMN]) {
    double (*result)[RELU_ARRAY_COLUMN] = (double *) malloc(sizeof(double) * RELU_ARRAY_ROW * RELU_ARRAY_COLUMN);
    for (int i = 0; i < RELU_ARRAY_ROW; i++) {
        for (int j = 0; j < RELU_ARRAY_COLUMN; ++j) {
            if (x[i] > 0) {
                result[i][j] = x[i][j];
            } else {
                result[i][j] = 0;
            }
        }
    }
    return result;
}
/**************************************************************************************************/

/**
 * 求均方误差时，数组参数的行数和列数
 */
int MSE_ROW;
int MSE_COLUMN;

/**
 * 求均方误差
 * @param testDataList
 * @param y
 * @return
 */
double **MeanSquareError(double (*trainLabelList)[1], double (*y)[MSE_COLUMN]){
    double (*meanSquareError)[MSE_COLUMN] = (double *) malloc(sizeof(double) * MSE_ROW * MSE_COLUMN);
    for (int i = 0; i < MSE_ROW; ++i) {
        double sum = pow(*(*(trainLabelList + i) + 0) - *(*(y + i) + 0), 2);
        meanSquareError[i][0] = sum / MSE_ROW;
    }
    return meanSquareError;
}
/**************************************************************************************************/

/**
 * 求row行1列的数组的平均数
 * @param trainLabelList
 * @param row
 * @return
 */
double ReduceMean(double (*meanSquareError)[1], int row){
    double sum = 0;
    for (int i = 0; i < row; ++i) {
        sum += *(*(meanSquareError + i) + 0);
    }
    return sum / row;
}
/**************************************************************************************************/

/**
 * 通过梯度下降算法，计算新的参数列表
 * @param parameterArray 旧的参数列表
 * @param learningRate 学习率
 * @return
 */
double *CalculateNewParameterByGradientDescentAlgorithm(double parameterArray[], int arrayLength, double learningRate){
    double *newParameterArray = (double *) malloc(sizeof(double) * arrayLength);
    for (int i = 0; i < arrayLength; ++i) {
        newParameterArray[i] = parameterArray[i] - learningRate * parameterArray[i];
    }
    return newParameterArray;
}
/**************************************************************************************************/

/**
 * 求损失值。通过求均方误差MSE的方法求损失值
 * @param count
 * @param w2ColumnNumber
 * @param trainLabelList
 * @param y
 * @return
 */
double CalculateLossValue(int count, int w2ColumnNumber, double (*trainLabelList)[1], double (*y)[w2ColumnNumber]) {
    MSE_ROW = count;
    MSE_COLUMN = w2ColumnNumber;
    double (*meanSquareError)[1] = MeanSquareError(trainLabelList, y);
    double mean = ReduceMean(meanSquareError, MSE_ROW);
    return mean;
}
/**************************************************************************************************/

/**
 * L2正则化（解决过拟合问题）
 * @param w1w2
 * @param rowNumber
 * @param l2RegularizationRate
 * @return
 */
double L2Regularization(double (*w1w2)[1], int rowNumber, double l2RegularizationRate) {
    double sum;
    for (int i = 0; i < rowNumber; i++) {
        sum += *(*(w1w2 + i) + 0) * *(*(w1w2 + i) + 0);
    }
    return l2RegularizationRate * sum;
}
/**************************************************************************************************/

/**
 * 训练迭代轮数，从0开始，最多为decaySteps
 */
int GLOBAL_STEP;

/**
 * 使用指数衰减法，计算学习率
 * @param learningRate
 * @param decayRate
 * @param globalStep
 * @param decaySteps
 * @return
 */
double CalculateLearningRateByExponentialDecay(double learningRate, double decayRate, int globalStep, int decaySteps) {
    return learningRate * pow(decayRate, globalStep / decaySteps);
}

